/*****************************************************************************
Filename    : rsa.c
Author      : Terrantsh (tanshanhe@foxmail.com)
Date        : 2018-8-31 10:31:23
Description : RSA encryption/decryption functions (PKCS#1 v1.5 padding)
*****************************************************************************/
#include <string.h>
#include <stdio.h>
#include <time.h>
#include <stdlib.h>

#include "RSA2048.h"
#include "RSA2048Bignum.h"
#include "RSA2048Keys.h"
#include "hal_trace.h"

/**
  * @brief  Fill a buffer with pseudo-random bytes (used for PKCS#1 v1.5
  *         type 2 padding).
  * @param  block      Destination buffer of block_len bytes.
  * @param  block_len  Number of bytes to generate (0 is allowed).
  * @return none
  * @note   The C library PRNG is seeded only once, on first use. Reseeding
  *         with time(NULL) before every byte (as this used to do) made every
  *         byte generated within the same second identical.
  *         NOTE(review): rand() is not a cryptographically secure generator
  *         and a time-based seed is guessable; replace with the platform's
  *         hardware/OS RNG for production use.
  */
void generate_rand(uint8_t *block, uint32_t block_len)
{
    static int seeded = 0;
    uint32_t i;

    if (!seeded) {
        srand((unsigned)time(NULL));
        seeded = 1;
    }

    for (i = 0; i < block_len; i++) {
        /* Take higher-order bits: the lowest bits of some simple LCG-based
         * rand() implementations have very short periods. */
        block[i] = (uint8_t)(rand() >> 7);
    }
}

/* Raw (unpadded) RSA primitives, defined later in this file.
 * public_block_operation exponentiates the input with the public exponent
 * modulo n; private_block_operation computes the private-key result using
 * the CRT parameters (p, q, dP, dQ, qInv). Both write (bits + 7) / 8 bytes
 * to out and report that length through out_len on success. */
static int public_block_operation(uint8_t *out, uint32_t *out_len, uint8_t *in, uint32_t in_len, rsa_pk_t *pk);
static int private_block_operation(uint8_t *out, uint32_t *out_len, uint8_t *in, uint32_t in_len, rsa_sk_t *sk);

/**
  * @brief  Encrypt a message with an RSA public key using PKCS#1 v1.5
  *         block type 2 padding: 0x00 0x02 PS 0x00 M, where PS is random
  *         non-zero filler of at least 8 bytes.
  * @param  out      Output buffer; receives (pk->bits + 7) / 8 bytes.
  * @param  out_len  Set by public_block_operation to the output length.
  * @param  in       Message to encrypt.
  * @param  in_len   Message length; at most the modulus length - 11.
  * @param  pk       Public key.
  * @return 0 on success, ERR_WRONG_LEN if the key does not fit the local
  *         buffer or the message is too long, otherwise the status returned
  *         by public_block_operation.
  */
int rsa_public_encrypt(uint8_t *out, uint32_t *out_len, uint8_t *in, uint32_t in_len, rsa_pk_t *pk)
{
    /* Calling memset through a volatile pointer stops the compiler from
     * eliding the final wipe of locals that are about to go out of scope. */
    static void *(*const volatile wipe)(void *, int, size_t) = memset;
    int status;
    uint8_t byte, pkcs_block[RSA_MAX_MODULUS_LEN];
    uint32_t i, modulus_len;

    modulus_len = (pk->bits + 7) / 8;
    /* The padded block must fit pkcs_block, and the message must leave room
     * for the 11 bytes of mandatory padding. The length test is written as a
     * subtraction so that a huge in_len cannot wrap around. */
    if(modulus_len > RSA_MAX_MODULUS_LEN || modulus_len < 11 || in_len > modulus_len - 11) {
        return ERR_WRONG_LEN;
    }

    pkcs_block[0] = 0;
    pkcs_block[1] = 2;
    /* Padding string PS: random non-zero bytes up to the separator. */
    for(i=2; i<modulus_len-in_len-1; i++) {
        do {
            generate_rand(&byte, 1);
        } while(byte == 0);
        pkcs_block[i] = byte;
    }

    pkcs_block[i++] = 0;    /* separator between PS and the message */

    memcpy(&pkcs_block[i], in, in_len);
    status = public_block_operation(out, out_len, pkcs_block, modulus_len, pk);

    // Clear potentially sensitive information
    wipe(&byte, 0, sizeof(byte));
    wipe(pkcs_block, 0, sizeof(pkcs_block));

    return status;
}

/**
  * @brief  Recover a message from a PKCS#1 v1.5 block type 1 block
  *         (0x00 0x01 0xFF... 0x00 M) with the public key; the inverse of
  *         rsa_private_encrypt.
  * @param  out      Output buffer; receives *out_len bytes.
  * @param  out_len  Set to the recovered message length on success only.
  * @param  in       Input block.
  * @param  in_len   Input length; at most the modulus length.
  * @param  pk       Public key.
  * @return 0 on success, ERR_WRONG_LEN / ERR_WRONG_DATA on bad sizes or bad
  *         padding, otherwise the status returned by public_block_operation.
  */
int rsa_public_decrypt(uint8_t *out, uint32_t *out_len, uint8_t *in, uint32_t in_len, rsa_pk_t *pk)
{
    /* Calling memset through a volatile pointer stops the compiler from
     * eliding the final wipe of locals that are about to go out of scope. */
    static void *(*const volatile wipe)(void *, int, size_t) = memset;
    int status;
    uint8_t pkcs_block[RSA_MAX_MODULUS_LEN];
    uint32_t i, modulus_len, pkcs_block_len, msg_len;

    modulus_len = (pk->bits + 7) / 8;
    /* public_block_operation writes modulus_len bytes into pkcs_block, and
     * the padding checks below index up to modulus_len - 1 (which must not
     * wrap for a zero-sized key). */
    if(modulus_len > RSA_MAX_MODULUS_LEN || modulus_len < 11 || in_len > modulus_len)
        return ERR_WRONG_LEN;

    status = public_block_operation(pkcs_block, &pkcs_block_len, in, in_len, pk);
    if(status != 0)
        goto cleanup;

    if(pkcs_block_len != modulus_len) {
        status = ERR_WRONG_LEN;
        goto cleanup;
    }

    if((pkcs_block[0] != 0) || (pkcs_block[1] != 1)) {
        status = ERR_WRONG_DATA;
        goto cleanup;
    }

    /* Skip the 0xFF padding string. */
    for(i=2; i<modulus_len-1; i++) {
        if(pkcs_block[i] != 0xFF)   break;
    }

    /* The padding must be terminated by a 0x00 separator. */
    if(pkcs_block[i++] != 0) {
        status = ERR_WRONG_DATA;
        goto cleanup;
    }

    /* At least 8 padding bytes are required, i.e. the message may be at
     * most modulus_len - 11 bytes long. */
    msg_len = modulus_len - i;
    if(msg_len + 11 > modulus_len) {
        status = ERR_WRONG_DATA;
        goto cleanup;
    }

    memcpy(out, &pkcs_block[i], msg_len);
    *out_len = msg_len;

cleanup:
    // Clear potentially sensitive information (on every exit path)
    wipe(pkcs_block, 0, sizeof(pkcs_block));
    return status;
}

/**
  * @brief  Transform a message with the RSA private key using PKCS#1 v1.5
  *         block type 1 padding: 0x00 0x01 0xFF... 0x00 M (at least 8 bytes
  *         of 0xFF).
  * @param  out      Output buffer; receives (sk->bits + 7) / 8 bytes.
  * @param  out_len  Set by private_block_operation to the output length.
  * @param  in       Message to process.
  * @param  in_len   Message length; at most the modulus length - 11.
  * @param  sk       Private key.
  * @return 0 on success, ERR_WRONG_LEN if the key does not fit the local
  *         buffer or the message is too long, otherwise the status returned
  *         by private_block_operation.
  */
int rsa_private_encrypt(uint8_t *out, uint32_t *out_len, uint8_t *in, uint32_t in_len, rsa_sk_t *sk)
{
    /* Calling memset through a volatile pointer stops the compiler from
     * eliding the final wipe of locals that are about to go out of scope. */
    static void *(*const volatile wipe)(void *, int, size_t) = memset;
    int status;
    uint8_t pkcs_block[RSA_MAX_MODULUS_LEN];
    uint32_t i, modulus_len;

    modulus_len = (sk->bits + 7) / 8;
    /* The padded block must fit pkcs_block, and the message must leave room
     * for the 11 bytes of mandatory padding. The length test is written as a
     * subtraction so that a huge in_len cannot wrap around. */
    if(modulus_len > RSA_MAX_MODULUS_LEN || modulus_len < 11 || in_len > modulus_len - 11)
        return ERR_WRONG_LEN;

    pkcs_block[0] = 0;
    pkcs_block[1] = 1;
    /* Padding string: 0xFF bytes up to the separator. */
    for(i=2; i<modulus_len-in_len-1; i++) {
        pkcs_block[i] = 0xFF;
    }

    pkcs_block[i++] = 0;    /* separator between padding and the message */

    memcpy(&pkcs_block[i], in, in_len);

    status = private_block_operation(out, out_len, pkcs_block, modulus_len, sk);

    // Clear potentially sensitive information
    wipe(pkcs_block, 0, sizeof(pkcs_block));

    return status;
}

/**
  * @brief  Decrypt a PKCS#1 v1.5 block type 2 ciphertext
  *         (0x00 0x02 PS 0x00 M) with the RSA private key.
  * @param  out      Output buffer; receives *out_len bytes.
  * @param  out_len  Set to the recovered message length on success only.
  * @param  in       Ciphertext.
  * @param  in_len   Ciphertext length; at most the modulus length.
  * @param  sk       Private key.
  * @return 0 on success, ERR_WRONG_LEN / ERR_WRONG_DATA on bad sizes or bad
  *         padding, otherwise the status returned by private_block_operation.
  * @note   The decrypted block is wiped on every exit path, including
  *         padding errors.
  *         NOTE(review): distinct error codes and early exits reveal whether
  *         the padding was valid (Bleichenbacher-style oracle); review this if
  *         callers expose the result to untrusted peers.
  */
int rsa_private_decrypt(uint8_t *out, uint32_t *out_len, uint8_t *in, uint32_t in_len, rsa_sk_t *sk)
{
    /* Calling memset through a volatile pointer stops the compiler from
     * eliding the final wipe of locals that are about to go out of scope. */
    static void *(*const volatile wipe)(void *, int, size_t) = memset;
    int status;
    uint8_t pkcs_block[RSA_MAX_MODULUS_LEN];
    uint32_t i, modulus_len, pkcs_block_len, msg_len;
    TRACE(0,"rsa_private_decrypt enter\n");
    modulus_len = (sk->bits + 7) / 8;
    /* private_block_operation writes modulus_len bytes into pkcs_block. */
    if(modulus_len > RSA_MAX_MODULUS_LEN || modulus_len < 11 || in_len > modulus_len)
        return ERR_WRONG_LEN;

    status = private_block_operation(pkcs_block, &pkcs_block_len, in, in_len, sk);
    if(status != 0)
        goto cleanup;

    if(pkcs_block_len != modulus_len) {
        status = ERR_WRONG_LEN;
        goto cleanup;
    }

    if((pkcs_block[0] != 0) || (pkcs_block[1] != 2)) {
        status = ERR_WRONG_DATA;
        goto cleanup;
    }

    /* Locate the 0x00 separator that ends the random padding string PS;
     * the last byte is included so an empty message is accepted. */
    for(i=2; i<modulus_len; i++) {
        if(pkcs_block[i] == 0)  break;
    }

    /* Reject a missing separator, or a PS shorter than the 8 bytes PKCS#1
     * v1.5 requires (separator before index 10). */
    if(i >= modulus_len || i < 10) {
        status = ERR_WRONG_DATA;
        goto cleanup;
    }

    i++;    /* skip the separator */
    msg_len = modulus_len - i;
    memcpy(out, &pkcs_block[i], msg_len);
    *out_len = msg_len;
    TRACE(0,"rsa_private_decrypt exit\n");

cleanup:
    // Clear potentially sensitive information (on every exit path)
    wipe(pkcs_block, 0, sizeof(pkcs_block));
    return status;
}

/**
  * @brief  Raw RSA public-key operation: exponentiate the input with the
  *         public exponent modulo n (no padding).
  * @param  out      Receives (pk->bits + 7) / 8 bytes on success.
  * @param  out_len  Set to (pk->bits + 7) / 8 on success.
  * @param  in       Input block.
  * @param  in_len   Input length in bytes.
  * @param  pk       Public key.
  * @return 0 on success, ERR_WRONG_DATA if the input is not smaller than n.
  * @note   The input and result bignums are wiped on every exit path; the
  *         input holds the padded plaintext when called for encryption.
  */
static int public_block_operation(uint8_t *out, uint32_t *out_len, uint8_t *in, uint32_t in_len, rsa_pk_t *pk)
{
    /* Calling memset through a volatile pointer stops the compiler from
     * eliding the final wipe of locals that are about to go out of scope. */
    static void *(*const volatile wipe)(void *, int, size_t) = memset;
    int status = 0;
    uint32_t edigits, ndigits;
    bn_t c[BN_MAX_DIGITS], e[BN_MAX_DIGITS], m[BN_MAX_DIGITS], n[BN_MAX_DIGITS];

    bn_decode(m, BN_MAX_DIGITS, in, in_len);
    bn_decode(n, BN_MAX_DIGITS, pk->modulus, RSA_MAX_MODULUS_LEN);
    bn_decode(e, BN_MAX_DIGITS, pk->exponent, RSA_MAX_MODULUS_LEN);

    ndigits = bn_digits(n, BN_MAX_DIGITS);
    edigits = bn_digits(e, BN_MAX_DIGITS);

    /* RSA is only defined for inputs smaller than the modulus. */
    if(bn_cmp(m, n, ndigits) >= 0) {
        status = ERR_WRONG_DATA;
        goto cleanup;
    }

    bn_mod_exp(c, m, e, edigits, n, ndigits);

    *out_len = (pk->bits + 7) / 8;
    bn_encode(out, *out_len, c, ndigits);

cleanup:
    // Clear potentially sensitive information (on every exit path)
    wipe(c, 0, sizeof(c));
    wipe(m, 0, sizeof(m));

    return status;
}

/**
  * @brief  Raw RSA private-key operation (no padding), computed with the
  *         Chinese Remainder Theorem:
  *         m = ((((mP - mQ) mod p) * qInv) mod p) * q + mQ,
  *         where mP = (c mod p)^dP mod p and mQ = (c mod q)^dQ mod q.
  * @param  out      Receives (sk->bits + 7) / 8 bytes on success.
  * @param  out_len  Set to (sk->bits + 7) / 8 on success.
  * @param  in       Input block.
  * @param  in_len   Input length in bytes.
  * @param  sk       Private key (modulus, primes, CRT exponents, coefficient).
  * @return 0 on success, ERR_WRONG_DATA if the input is not smaller than n.
  * @note   All operations on q reuse p's digit count, so this assumes q is no
  *         longer than p — NOTE(review): confirm p > q for the keys in use.
  *         Key material copied to the stack is wiped on every exit path.
  */
static int private_block_operation(uint8_t *out, uint32_t *out_len, uint8_t *in, uint32_t in_len, rsa_sk_t *sk)
{
    /* Calling memset through a volatile pointer stops the compiler from
     * eliding the final wipe of locals that are about to go out of scope. */
    static void *(*const volatile wipe)(void *, int, size_t) = memset;
    TRACE(0,"private_block_operation enter\n");
    int status = 0;
    uint32_t cdigits, ndigits, pdigits;
    bn_t c[BN_MAX_DIGITS], cp[BN_MAX_DIGITS], cq[BN_MAX_DIGITS];
    bn_t dp[BN_MAX_DIGITS], dq[BN_MAX_DIGITS], mp[BN_MAX_DIGITS], mq[BN_MAX_DIGITS];
    bn_t n[BN_MAX_DIGITS], p[BN_MAX_DIGITS], q[BN_MAX_DIGITS], q_inv[BN_MAX_DIGITS], t[BN_MAX_DIGITS];

    bn_decode(c, BN_MAX_DIGITS, in, in_len);
    bn_decode(n, BN_MAX_DIGITS, sk->modulus, RSA_MAX_MODULUS_LEN);
    bn_decode(p, BN_MAX_DIGITS, sk->prime1, RSA_MAX_PRIME_LEN);
    bn_decode(q, BN_MAX_DIGITS, sk->prime2, RSA_MAX_PRIME_LEN);
    bn_decode(dp, BN_MAX_DIGITS, sk->prime_exponent1, RSA_MAX_PRIME_LEN);
    bn_decode(dq, BN_MAX_DIGITS, sk->prime_exponent2, RSA_MAX_PRIME_LEN);
    bn_decode(q_inv, BN_MAX_DIGITS, sk->coefficient, RSA_MAX_PRIME_LEN);

    cdigits = bn_digits(c, BN_MAX_DIGITS);
    ndigits = bn_digits(n, BN_MAX_DIGITS);
    pdigits = bn_digits(p, BN_MAX_DIGITS);

    /* The private key is already decoded onto the stack at this point, so
     * the error path must go through the wipe below as well. */
    if(bn_cmp(c, n, ndigits) >= 0) {
        status = ERR_WRONG_DATA;
        goto cleanup;
    }

    /* Half-size exponentiations modulo each prime. */
    bn_mod(cp, c, cdigits, p, pdigits);
    bn_mod(cq, c, cdigits, q, pdigits);
    bn_mod_exp(mp, cp, dp, pdigits, p, pdigits);
    bn_assign_zero(mq, ndigits);    /* upper digits must be zero for the final ndigits-wide add */
    bn_mod_exp(mq, cq, dq, pdigits, q, pdigits);

    /* t = (mP - mQ) mod p, kept non-negative. */
    if(bn_cmp(mp, mq, pdigits) >= 0) {
        bn_sub(t, mp, mq, pdigits);
    } else {
        bn_sub(t, mq, mp, pdigits);
        bn_sub(t, p, t, pdigits);
    }

    /* Recombine: m = ((t * qInv) mod p) * q + mQ. */
    bn_mod_mul(t, t, q_inv, p, pdigits);
    bn_mul(t, t, q, pdigits);
    bn_add(t, t, mq, ndigits);

    *out_len = (sk->bits + 7) / 8;
    bn_encode(out, *out_len, t, ndigits);
    TRACE(0,"private_block_operation exit\n");

cleanup:
    // Clear potentially sensitive information (on every exit path)
    wipe(c, 0, sizeof(c));
    wipe(cp, 0, sizeof(cp));
    wipe(cq, 0, sizeof(cq));
    wipe(dp, 0, sizeof(dp));
    wipe(dq, 0, sizeof(dq));
    wipe(mp, 0, sizeof(mp));
    wipe(mq, 0, sizeof(mq));
    wipe(p, 0, sizeof(p));
    wipe(q, 0, sizeof(q));
    wipe(q_inv, 0, sizeof(q_inv));
    wipe(t, 0, sizeof(t));

    return status;
}

/**
  * @brief  Print a byte array as hex, 16 bytes per line.
  * @param  str  Bytes to print (not modified).
  * @param  len  Number of bytes to print.
  * @return none
  * @note   A newline is printed before every group of 16 bytes, followed by
  *         a blank line at the end.
  */
static void TRACE1(const uint8_t *str, uint32_t len)
{
    uint32_t i;    /* same type as len: avoids a signed/unsigned comparison */

    for (i = 0; i < len; i++)
    {
        if (i % 16 == 0)
            printf("\n");
        printf("0x%02x ", (unsigned)str[i]);
    }
    printf("\n\n");
}

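/*
 * RSA-2048 encrypt/decrypt self-test (currently disabled).
 * Requires this file, the RSA2048.h, RSA2048Bignum.h and RSA2048Keys.h
 * headers, and the matching big-number implementation.
 */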
// static int RSA2048(void){
//     rsa_pk_t pk = {0};
//     rsa_sk_t sk = {0};
//     uint8_t output[256];

//     // message to encrypt
//     uint8_t input [256] = { 0x21,0x55,0x53,0x53,0x53,0x53};

//     unsigned char msg [256];
//     uint32_t outputLen, msg_len;
//     uint8_t  inputLen;

//     // copy keys.h message about public key and private key to the flash RAM
//     pk.bits = KEY_M_BITS;
//     memcpy(&pk.modulus         [RSA_MAX_MODULUS_LEN-sizeof(key_m) ],  key_m,  sizeof(key_m ));
//     memcpy(&pk.exponent        [RSA_MAX_MODULUS_LEN-sizeof(key_e) ],  key_e,  sizeof(key_e ));
//     sk.bits = KEY_M_BITS;
//     memcpy(&sk.modulus         [RSA_MAX_MODULUS_LEN-sizeof(key_m) ],  key_m,  sizeof(key_m ));
//     memcpy(&sk.public_exponet  [RSA_MAX_MODULUS_LEN-sizeof(key_e) ],  key_e,  sizeof(key_e ));
//     memcpy(&sk.exponent        [RSA_MAX_MODULUS_LEN-sizeof(key_pe)],  key_pe, sizeof(key_pe));
//     memcpy(&sk.prime1          [RSA_MAX_PRIME_LEN - sizeof(key_p1)],  key_p1, sizeof(key_p1));
//     memcpy(&sk.prime2          [RSA_MAX_PRIME_LEN - sizeof(key_p2)],  key_p2, sizeof(key_p2));
//     memcpy(&sk.prime_exponent1 [RSA_MAX_PRIME_LEN - sizeof(key_e1)],  key_e1, sizeof(key_e1));
//     memcpy(&sk.prime_exponent2 [RSA_MAX_PRIME_LEN - sizeof(key_e2)],  key_e2, sizeof(key_e2));
//     memcpy(&sk.coefficient     [RSA_MAX_PRIME_LEN - sizeof(key_c) ],  key_c,  sizeof(key_c ));

//     inputLen = strlen((const char*)input);

//     // public key encrypt
//     rsa_public_encrypt(output, &outputLen, input, inputLen, &pk);
//     printf("public key encrypt is %d\n",outputLen);
//     TRACE1(output,outputLen);

//     // private key decrypt
//     rsa_private_decrypt(msg, &msg_len, output, outputLen, &sk);
//     printf("private key decrypt  len is %d\n",msg_len);
//     TRACE1(msg,msg_len);

//     // private key encrypt
//     rsa_private_encrypt(output, &outputLen, input, inputLen, &sk);
//     printf("private key encrypt len is %d\n",outputLen);
//     TRACE1(output,outputLen);

//     // public key decrypted
//     rsa_public_decrypt(msg, &msg_len, output, outputLen, &pk);
//     printf("public key decrypted len is %d\n",msg_len);
//     TRACE1(msg,msg_len);

//     return 0;
// }
/* RSA2048 function ended */

/**
  * @brief  Encrypt plaintext with the built-in RSA-2048 public key (from
  *         RSA2048Keys.h) and print the plaintext and the ciphertext.
  * @param  plaintext       Message to encrypt.
  * @param  plaintext_Len   Message length in bytes (at most modulus length - 11).
  * @param  ciphertext      Output buffer; receives (KEY_M_BITS + 7) / 8 bytes.
  * @param  ciphertext_Len  Passed by value: the length produced by the
  *                         encryption is only printed, not returned.
  * @return none
  * @note   On failure an error message is printed and ciphertext is not dumped.
  *         NOTE(review): the plaintext is printed to the console; confirm that
  *         is acceptable if it carries key material.
  */
void RSA2048EnCode(uint8_t plaintext[256], uint32_t plaintext_Len, uint8_t ciphertext[256], uint32_t ciphertext_Len)
{
    rsa_pk_t pk = {0};
    int status;

    // Load the public key from RSA2048Keys.h; each value is right-aligned
    // in its fixed-size field.
    pk.bits = KEY_M_BITS;
    memcpy(&pk.modulus         [RSA_MAX_MODULUS_LEN-sizeof(key_m) ],  key_m,  sizeof(key_m ));
    memcpy(&pk.exponent        [RSA_MAX_MODULUS_LEN-sizeof(key_e) ],  key_e,  sizeof(key_e ));

    //plaintext
    printf("plaintext is:");
    TRACE1(plaintext,plaintext_Len);

    // public key encrypt
    status = rsa_public_encrypt(ciphertext, &ciphertext_Len, plaintext, plaintext_Len, &pk);
    if (status != 0) {
        // ciphertext contents are unspecified on failure, so do not dump them.
        printf("public key encrypt failed: %d\n", status);
        return;
    }
    printf("public key encrypt is %u\n", (unsigned)ciphertext_Len);
    TRACE1(ciphertext,ciphertext_Len);
}

/**
  * @brief  Decrypt ciphertext with the built-in RSA-2048 private key (from
  *         RSA2048Keys.h) and trace the recovered plaintext.
  * @param  ciphertext      Ciphertext block.
  * @param  ciphertext_Len  Ciphertext length in bytes.
  * @param  plaintext       Output buffer for the recovered message.
  * @param  plaintext_Len   Passed by value: the recovered length is only
  *                         traced, not returned to the caller.
  * @return none
  * @note   The on-stack copy of the private key is wiped before returning.
  *         On failure an error is traced and plaintext is not dumped.
  *         NOTE(review): the decrypted plaintext is printed to the console;
  *         confirm that is acceptable if it carries key material.
  */
void RSA2048DeCode(uint8_t ciphertext[256], uint32_t ciphertext_Len, uint8_t plaintext[256], uint32_t plaintext_Len)
{
    /* Calling memset through a volatile pointer stops the compiler from
     * eliding the wipe of the private key before it goes out of scope. */
    static void *(*const volatile wipe)(void *, int, size_t) = memset;
    int status;
    TRACE(0,"RSA2048DeCode enter\n");
    rsa_sk_t sk = {0};
    TRACE(0,"RSA2048DeCode enter1\n");
    // Load the private key from RSA2048Keys.h; each value is right-aligned
    // in its fixed-size field.
    sk.bits = KEY_M_BITS;
    memcpy(&sk.modulus         [RSA_MAX_MODULUS_LEN-sizeof(key_m) ],  key_m,  sizeof(key_m ));
    memcpy(&sk.public_exponet  [RSA_MAX_MODULUS_LEN-sizeof(key_e) ],  key_e,  sizeof(key_e ));
    memcpy(&sk.exponent        [RSA_MAX_MODULUS_LEN-sizeof(key_pe)],  key_pe, sizeof(key_pe));
    memcpy(&sk.prime1          [RSA_MAX_PRIME_LEN - sizeof(key_p1)],  key_p1, sizeof(key_p1));
    memcpy(&sk.prime2          [RSA_MAX_PRIME_LEN - sizeof(key_p2)],  key_p2, sizeof(key_p2));
    memcpy(&sk.prime_exponent1 [RSA_MAX_PRIME_LEN - sizeof(key_e1)],  key_e1, sizeof(key_e1));
    memcpy(&sk.prime_exponent2 [RSA_MAX_PRIME_LEN - sizeof(key_e2)],  key_e2, sizeof(key_e2));
    memcpy(&sk.coefficient     [RSA_MAX_PRIME_LEN - sizeof(key_c) ],  key_c,  sizeof(key_c ));
    TRACE(0,"RSA2048DeCode enter2\n");
    // private key decrypt
    status = rsa_private_decrypt(plaintext, &plaintext_Len, ciphertext, ciphertext_Len, &sk);
    // The private key copy must not be left behind on the stack.
    wipe(&sk, 0, sizeof(sk));
    if (status != 0) {
        // plaintext contents are unspecified on failure, so do not dump them.
        TRACE(0,"private key decrypt failed: %d\n", status);
        return;
    }
    TRACE(0,"private key decrypt  len is %d\n",(int)plaintext_Len);
    TRACE1(plaintext,plaintext_Len);
    TRACE(0,"RSA2048DeCode exit\n");
}

// int main(int argc, char const *argv[])
// {
//     // message to encrypt
//     uint8_t inputkey [256] = { 0xad, 0xfe, 0x12, 0x73, 0x3a, 0x16, 0xc7, 0xc7, 0x02, 0xed, 0xcf, 0x00, 0xb4, 0x5b, 0xde, 0xbe }; //AES128 EnCode Key
//     uint32_t inputkey_Len = 16; //input key len
//     uint8_t outputkey[256];
//     uint32_t outputkey_Len = 256;
//     uint8_t msgkey [256];
//     uint32_t msgkey_Len = 16;

//     clock_t start, finish;
//     double  duration;
//     start = clock();    // init start time
//     if (0)
//     {
//         RSA2048(); //test running   
//     }
//     else
//     {
//         RSA2048EnCode(inputkey,inputkey_Len,outputkey,outputkey_Len);
//         RSA2048DeCode(outputkey,outputkey_Len,msgkey,msgkey_Len);
//     }
//     finish = clock();   // print end time
//     duration = (double)(finish - start) / CLOCKS_PER_SEC;   // print encrypt and decrypt time
//     printf( "%f seconds\n", duration );
//     return 0;
// }
